#!/usr/bin/env python
"""Run evaluation on a single image, using an NBDT"""

from nbdt.model import SoftNBDT, HardNBDT
from pytorchcv.models.wrn_cifar import wrn28_10_cifar10
from torchvision import transforms
from nbdt.utils import DATASET_TO_CLASSES, load_image_from_path
from nbdt.thirdparty.wn import maybe_install_wordnet
import sys

# Validate the command line before doing any setup work. An explicit check is
# used instead of `assert`, because assertions are removed when Python runs
# with -O and the script would then fail later on `sys.argv[1]`.
if len(sys.argv) <= 1:
    sys.exit("Need to pass image URL or image path as argument")

# Set up WordNet via the nbdt helper before the NBDT is constructed below.
maybe_install_wordnet()

# Load the pretrained NBDT: build the WideResNet backbone, then wrap it in a
# SoftNBDT configured for CIFAR10 with pretrained weights requested.
model = wrn28_10_cifar10()
model = SoftNBDT(
    pretrained=True, dataset="CIFAR10", arch="wrn28_10_cifar10", model=model
)
# This script only runs inference on a single image, so switch to evaluation
# mode: any batch-norm / dropout layers then use their inference behaviour
# instead of training behaviour on a batch of one.
model.eval()

# Load the image named by the first command-line argument (URL or local path,
# per the usage message) and turn it into a normalised 1-image batch.
im = load_image_from_path(sys.argv[1])

# Per-channel mean / std used for normalisation
# (presumably the CIFAR10 training-set statistics; verify against training).
normalize_mean = (0.4914, 0.4822, 0.4465)
normalize_std = (0.2023, 0.1994, 0.2010)

# NOTE(review): Normalize is given three channels; confirm that
# load_image_from_path yields RGB images (e.g. not RGBA or grayscale).
preprocessing_steps = [
    transforms.Resize(32),  # resize to the 32-pixel network input scale
    transforms.CenterCrop(32),  # then take the central 32x32 patch
    transforms.ToTensor(),
    transforms.Normalize(normalize_mean, normalize_std),
]
transform = transforms.Compose(preprocessing_steps)

# Add a leading batch dimension so the model receives a batch of size one.
x = transform(im).unsqueeze(0)

# Run inference, collecting the per-sample decisions alongside the outputs.
# Use `model(x)` instead to obtain just the logits.
outputs, decisions = model.forward_with_decisions(x)

# Index of the highest-scoring class along dimension 1 for each sample.
predicted = outputs.max(1)[1]
cls = DATASET_TO_CLASSES["CIFAR10"][predicted[0]]

# Describe every decision recorded for the first (only) sample.
decision_template = "{} (Confidence: {:.2f}%)"
described = []
for info in decisions[0]:
    confidence = (1 - info["entropy"]) * 100
    described.append(decision_template.format(info["name"], confidence))

# The first entry is the root, so it is skipped in the printed path.
print("Prediction:", cls, "// Decisions:", ", ".join(described[1:]))
